/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Unit tests for epilogues
*/
#pragma once

#include <fstream>

#include "../../common/cutlass_unit_test.h"

#include "cutlass/aligned_buffer.h"
#include "cutlass/half.h"
#include "cutlass/complex.h"

#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"

#include "cutlass/util/host_tensor_planar_complex.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace test {
namespace kernel {

template <typename Epilogue>
__global__ void epilogue_planar_complex_threadblock(
        typename Epilogue::OutputTileIterator::Params params_D,
        typename Epilogue::OutputTileIterator::Element* ptr_D,
        int64_t imaginary_stride_D,
        typename Epilogue::OutputTileIterator::Params params_C,
        typename Epilogue::OutputTileIterator::Element* ptr_C,
        int64_t imaginary_stride_C,
        typename Epilogue::OutputOp::Params params_output_op,
        cutlass::MatrixCoord problem_size,
        cutlass::TensorRef<typename Epilogue::WarpMmaOperator::ElementC,
                           typename Epilogue::WarpMmaOperator::LayoutC>
                accumulator_ref,
        int64_t imaginary_stride_accum, int epilogue_count = 1) {
    __shared__ typename Epilogue::SharedStorage shared_storage;

    int thread_idx = threadIdx.x;
    int warp_idx = threadIdx.x / 32;
    int lane_idx = threadIdx.x % 32;

    //
    // Construct the epilogue
    //

    // Tile iterator writing to output tile
    typename Epilogue::OutputTileIterator iterator_D_real(
            params_D, ptr_D, problem_size, thread_idx);

    typename Epilogue::OutputTileIterator iterator_D_imag(
            params_D, ptr_D + imaginary_stride_D, problem_size, thread_idx);

    // Tile iterator writing to output tile
    typename Epilogue::OutputTileIterator iterator_C_real(
            params_C, ptr_C, problem_size, thread_idx);

    typename Epilogue::OutputTileIterator iterator_C_imag(
            params_C, ptr_C + imaginary_stride_C, problem_size, thread_idx);

    // Epilogue operator
    Epilogue epilogue(shared_storage, thread_idx, warp_idx, lane_idx);

    //
    // Initialize the accumulators
    //

    int warp_mn =
            warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN);
    int warp_m = warp_mn % Epilogue::WarpCount::kM;
    int warp_n = warp_mn / Epilogue::WarpCount::kM;

    accumulator_ref.add_coord_offset(
            {warp_m * Epilogue::WarpMmaOperator::Shape::kM,
             warp_n * Epilogue::WarpMmaOperator::Shape::kN});

    //
    // Load accumulators
    //

    typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(
            accumulator_ref, lane_idx);

    typename Epilogue::AccumulatorTile accumulators;

    accumulators.clear();

    accumulator_iterator.load(accumulators.real);
    accumulator_iterator.load_with_pointer_offset(accumulators.imag,
                                                  imaginary_stride_accum);

    //
    // Perform the epilogue operation
    //

    typename Epilogue::OutputOp output_op(params_output_op);

    // Place the epilogue in a loop so assembly is clearly visible
    for (int iter = 0; iter < epilogue_count; ++iter) {
        epilogue(output_op, iterator_D_real, iterator_D_imag, accumulators,
                 iterator_C_real, iterator_C_imag);
    }
}

}  // namespace kernel
}  // namespace test

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Epilogue_>
class EpiloguePlanarComplexTestbed {
public:
    using Epilogue = Epilogue_;
    using ElementAccumulator = typename Epilogue::ElementAccumulator;
    using ElementCompute = typename Epilogue::OutputOp::ElementCompute;
    using ElementOutput = typename Epilogue::ElementOutput;
    using OutputOpParams = typename Epilogue::OutputOp::Params;

    using ComplexElementOutput = cutlass::complex<ElementOutput>;
    using ComplexElementAccumulator = cutlass::complex<ElementAccumulator>;
    using ComplexElementCompute = cutlass::complex<ElementCompute>;

public:
    //
    // Data members
    //

    cutlass::MatrixCoord quantized_size;
    cutlass::HostTensorPlanarComplex<ElementAccumulator,
                                     cutlass::layout::RowMajor>
            accumulator_tensor;
    cutlass::HostTensorPlanarComplex<ElementOutput, cutlass::layout::RowMajor>
            source_tensor;
    cutlass::HostTensorPlanarComplex<ElementOutput, cutlass::layout::RowMajor>
            output_tensor;

public:
    //
    // Methods
    //

    EpiloguePlanarComplexTestbed()
            : quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN),
              accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}),
              source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}),
              output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}) {
        //
        // Initialize problem space
        //

#if 1
        uint64_t seed = 2019;

        cutlass::reference::host::TensorFillRandomUniform(
                accumulator_tensor.host_view(), seed, 20, -20, 0);

        cutlass::reference::host::TensorFillRandomUniform(
                source_tensor.host_view(), seed + 2018, 20, -20, 0);
#else

        cutlass::reference::host::BlockFillSequential(
                accumulator_tensor.host_data(), accumulator_tensor.capacity());

#endif
    }

    bool run_all() {
        cutlass::complex<float> alpha_values[3];

        alpha_values[0] = cutlass::complex<float>(1, 0);
        alpha_values[1] = cutlass::complex<float>(0, 0);
        alpha_values[2] = cutlass::complex<float>(2.25f, -0.5f);

        cutlass::complex<float> beta_values[3];

        beta_values[0] = cutlass::complex<float>(0, 0);
        beta_values[1] = cutlass::complex<float>(1, 0);
        beta_values[2] = cutlass::complex<float>(0.5f, -2.25f);

        // Test runtime explodes if we tried to test every case exhaustively.
        // This tests the full output tile and several smaller sizes to stress
        // predication.
        for (int m_idx = 0; m_idx < 3; ++m_idx) {
            for (int n_idx = 0; n_idx < 3; ++n_idx) {
                cutlass::MatrixCoord problem_size(
                        quantized_size.row() - m_idx * 3,
                        quantized_size.column() -
                                n_idx * Epilogue::kElementsPerAccess);

                for (auto const& alpha : alpha_values) {
                    for (auto const& beta : beta_values) {
                        bool passed = run(problem_size, {alpha, beta});

                        if (!passed) {
                            return false;
                        }
                    }
                }
            }
        }

        return true;
    }

    /// Runs the test
    bool run(cutlass::MatrixCoord problem_size, OutputOpParams output_params) {
        //
        // Initialize problem space
        //

        ComplexElementOutput default_output =
                ComplexElementOutput(ElementOutput(-127), ElementOutput(-101));

        cutlass::reference::host::TensorFill(output_tensor.host_view(),
                                             default_output);

        accumulator_tensor.sync_device();
        output_tensor.sync_device();
        source_tensor.sync_device();

        //
        // Initialize epilogue parameters
        //

        typename Epilogue::OutputTileIterator::Params params_D(
                output_tensor.layout());
        typename Epilogue::OutputTileIterator::Params params_C(
                source_tensor.layout());

        //
        // Launch kernel
        //

        dim3 grid(1, 1);
        dim3 block(Epilogue::WarpCount::kCount * 32, 1);

        test::kernel::epilogue_planar_complex_threadblock<Epilogue>
                <<<grid, block>>>(params_D, output_tensor.device_data(),
                                  output_tensor.imaginary_stride(), params_C,
                                  source_tensor.device_data(),
                                  source_tensor.imaginary_stride(),
                                  output_params, problem_size,
                                  accumulator_tensor.device_view_real(),
                                  accumulator_tensor.imaginary_stride());

        cudaError_t result = cudaDeviceSynchronize();

        if (result != cudaSuccess) {
            std::cerr << "Kernel error: " << cudaGetErrorString(result)
                      << std::endl;
            return false;
        }

        //
        // Verify results
        //
        output_tensor.sync_host();

        int errors = 0;
        int const kMaxErrors = 5;

        for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) {
            for (int c = 0; errors < kMaxErrors && c < quantized_size.column();
                 ++c) {
                cutlass::MatrixCoord coord{r, c};
                ComplexElementOutput got = output_tensor.at(coord);

                ComplexElementOutput expected = default_output;

                if (coord.row() < problem_size.row() &&
                    coord.column() < problem_size.column()) {
                    ComplexElementOutput src = source_tensor.at(coord);

                    ComplexElementCompute tmp =
                            output_params.alpha *
                                    ComplexElementCompute(
                                            accumulator_tensor.at(coord)) +
                            output_params.beta *
                                    ComplexElementCompute(src.real(),
                                                          src.imag());

                    expected = ComplexElementOutput(ElementOutput(tmp.real()),
                                                    ElementOutput(tmp.imag()));
                }

                if (expected != got) {
                    using OutputIO = cutlass::ScalarIO<ComplexElementOutput>;

                    EXPECT_TRUE(false)
                            << "-------\n"
                            << "Error - output element (" << coord
                            << ") - expected: " << OutputIO(expected)
                            << ",  got: " << OutputIO(got) << std::endl;

                    ++errors;
                }
            }
        }

        //
        // Report results on error
        //

        if (errors) {
            std::cout << "Incorrect result for problem(" << problem_size.row()
                      << ", " << problem_size.column()
                      << ") for alpha: " << output_params.alpha
                      << ", beta: " << output_params.beta << std::endl;

            std::stringstream ss;
            ss << "output_tensor_op_" << Epilogue::Shape::kM << "x"
               << Epilogue::Shape::kN << "_"
               << Epilogue::WarpTileIterator::WarpShape::kM << "x"
               << Epilogue::WarpTileIterator::WarpShape::kN << "_slice_"
               << Epilogue::WarpCount::kK << ".csv";

            std::ofstream output_file(ss.str());
            output_file << output_tensor.host_view();

            std::cout << "Wrote workspace to '" << ss.str() << "'" << std::endl;
        }

        return !errors;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////
